{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "from torchinfo import summary\n",
    "\n",
    "class SingleInputNet(nn.Module):\n",
    "    \"\"\" Simple CNN model. \"\"\"\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d(0.3)\n",
    "        self.fc1 = nn.Linear(320, 50)\n",
    "        self.fc2 = nn.Linear(50, 10)\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        x = x.view(-1, 320)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "# One model instance and one dummy input shape, reused by every summary cell below.\n",
    "# The shape is (batch, channels, height, width) = (2, 1, 28, 28).\n",
    "model = SingleInputNet()\n",
    "input_size = (2, 1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "==========================================================================================\n",
       "Layer (type:depth-idx)                   Output Shape              Param #\n",
       "==========================================================================================\n",
       "├─Conv2d: 1-1                            [2, 10, 24, 24]           260\n",
       "├─Conv2d: 1-2                            [2, 20, 8, 8]             5,020\n",
       "├─Dropout2d: 1-3                         [2, 20, 8, 8]             --\n",
       "├─Linear: 1-4                            [2, 50]                   16,050\n",
       "├─Linear: 1-5                            [2, 10]                   510\n",
       "==========================================================================================\n",
       "Total params: 21,840\n",
       "Trainable params: 21,840\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (M): 0.96\n",
       "==========================================================================================\n",
       "Input size (MB): 0.01\n",
       "Forward/backward pass size (MB): 0.11\n",
       "Params size (MB): 0.09\n",
       "Estimated Total Size (MB): 0.21\n",
       "=========================================================================================="
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# verbose=0 is passed explicitly. The call is the cell's last expression, so the\n",
    "# returned summary object is what Jupyter shows as this cell's result.\n",
    "summary(model, input_size=input_size, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "==========================================================================================\n",
       "Layer (type:depth-idx)                   Output Shape              Param #\n",
       "==========================================================================================\n",
       "├─Conv2d: 1-1                            [2, 10, 24, 24]           260\n",
       "├─Conv2d: 1-2                            [2, 20, 8, 8]             5,020\n",
       "├─Dropout2d: 1-3                         [2, 20, 8, 8]             --\n",
       "├─Linear: 1-4                            [2, 50]                   16,050\n",
       "├─Linear: 1-5                            [2, 10]                   510\n",
       "==========================================================================================\n",
       "Total params: 21,840\n",
       "Trainable params: 21,840\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (M): 0.96\n",
       "==========================================================================================\n",
       "Input size (MB): 0.01\n",
       "Forward/backward pass size (MB): 0.11\n",
       "Params size (MB): 0.09\n",
       "Estimated Total Size (MB): 0.21\n",
       "=========================================================================================="
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same call without verbose: the stored output matches the verbose=0 cell, i.e. the\n",
    "# table appears once, as the cell result rather than as printed text.\n",
    "summary(model, input_size=input_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# summary() is not the last expression here, so its return value is discarded;\n",
    "# the only output of this cell is the value of 2+2.\n",
    "summary(model, input_size=input_size)\n",
    "2+2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========================================================================================\n",
      "Layer (type:depth-idx)                   Output Shape              Param #\n",
      "==========================================================================================\n",
      "├─Conv2d: 1-1                            [2, 10, 24, 24]           260\n",
      "├─Conv2d: 1-2                            [2, 20, 8, 8]             5,020\n",
      "├─Dropout2d: 1-3                         [2, 20, 8, 8]             --\n",
      "├─Linear: 1-4                            [2, 50]                   16,050\n",
      "├─Linear: 1-5                            [2, 10]                   510\n",
      "==========================================================================================\n",
      "Total params: 21,840\n",
      "Trainable params: 21,840\n",
      "Non-trainable params: 0\n",
      "Total mult-adds (M): 0.96\n",
      "==========================================================================================\n",
      "Input size (MB): 0.01\n",
      "Forward/backward pass size (MB): 0.11\n",
      "Params size (MB): 0.09\n",
      "Estimated Total Size (MB): 0.21\n",
      "==========================================================================================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrapping the call in print() writes the table to stdout, so it is shown even\n",
    "# though the cell's last expression (and displayed result) is 2+2.\n",
    "print(summary(model, input_size=input_size))\n",
    "2+2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
